{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "authorship_tag": "ABX9TyMkCp2DFyJCa1E9F3s7+Iz5",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/DanielWarfield1/MLWritingAndResearch/blob/main/RAGFromScratch.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# RAG From Scratch\n",
        "This notebook is a low-level conceptual exploration of retrieval augmented generation (RAG). We use a word vector encoder to embed words, calculate the mean vector of each document and prompt, and use Manhattan distance as the distance metric.\n",
        "\n",
        "There are surely more efficient/better ways to get this done, which I'll explore in future demos. For now, these are the low-level fundamentals.\n",
        "\n",
        "Note: The terms \"embedding\" and \"encoding\" are painfully interchangeable. Generally \"encode\" is a verb and \"embedding\" is a noun, so you \"encode words into an embedding\", but it's also common to say you \"embed words into an embedding\". I tend to flip between the two depending on the context."
      ],
      "metadata": {
        "id": "_cTk6k--hFjO"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Loading Word Space Encoder\n"
      ],
      "metadata": {
        "id": "N0Qaw64ejYdh"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Xj9HiosyhBPu",
        "outputId": "a3c330ba-cf4d-4589-fc7b-23f3f470bac8"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "array([ 0.85337  ,  0.011645 , -0.033377 , -0.31981  ,  0.26126  ,\n",
              "        0.16059  ,  0.010724 , -0.15542  ,  0.75044  ,  0.10688  ,\n",
              "        1.9249   , -0.45915  , -3.3887   , -1.2152   , -0.054263 ,\n",
              "       -0.20555  ,  0.54706  ,  0.4371   ,  0.25194  ,  0.0086557,\n",
              "       -0.56612  , -1.1762   ,  0.010479 , -0.55316  , -0.15816  ],\n",
              "      dtype=float32)"
            ]
          },
          "metadata": {},
          "execution_count": 3
        }
      ],
      "source": [
        "\"\"\"Downloading a word encoder.\n",
        "I was going to use word2vec, but GloVe downloads way faster. For our purposes\n",
        "they're conceptually identical.\n",
        "\"\"\"\n",
        "\n",
        "import gensim.downloader\n",
        "\n",
        "# downloading the encoder\n",
        "word_encoder = gensim.downloader.load('glove-twitter-25')\n",
        "\n",
        "# getting the embedding for a word\n",
        "word_encoder['apple']"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Embedding text\n",
        "Embed either the document or the prompt by calculating the mean of its word vectors."
      ],
      "metadata": {
        "id": "8i5XlcKAFmqR"
      }
    },
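    {
      "cell_type": "code",
      "source": [
        "\"\"\"Defining a function for embedding an entire document as a single mean vector.\n",
        "\"\"\"\n",
        "\n",
        "import numpy as np\n",
        "\n",
        "def embed_sequence(sequence):\n",
        "    # look up a vector for every space-separated word; a word missing from\n",
        "    # the encoder's vocabulary raises a KeyError\n",
        "    vects = word_encoder[sequence.split(' ')]\n",
        "    # average over the word axis to get one vector per sequence\n",
        "    return np.mean(vects, axis=0)\n",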
        "\n",
        "embed_sequence('its a sunny day today')"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "90Ri00AvA-t9",
        "outputId": "05e1e4d8-1a84-4a92-9ed9-2af9c6eb4a2d"
      },
      "execution_count": 23,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "array([-6.3483393e-01,  1.3683620e-01,  2.0645106e-01, -2.1831200e-01,\n",
              "       -1.8181981e-01,  2.6023200e-01,  1.3276964e+00,  1.7272198e-01,\n",
              "       -2.7881199e-01, -4.2115799e-01, -4.7215199e-01, -5.3013992e-02,\n",
              "       -4.6326599e+00,  4.3883198e-01,  3.6487383e-01, -3.6672002e-01,\n",
              "       -2.6924044e-03, -3.0394283e-01, -5.5415201e-01, -9.1787003e-02,\n",
              "       -4.4997922e-01, -1.4819117e-01,  1.0654800e-01,  3.7024397e-01,\n",
              "       -4.6688594e-02], dtype=float32)"
            ]
          },
          "metadata": {},
          "execution_count": 23
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Defining distance calculation"
      ],
      "metadata": {
        "id": "D-VuWZizI9BU"
      }
    },
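    {
      "cell_type": "code",
      "source": [
        "from scipy.spatial.distance import cdist\n",
        "\n",
        "def calc_distance(embedding1, embedding2):\n",
        "    # cdist expects 2D arrays, so each vector becomes a single-row matrix;\n",
        "    # 'cityblock' is scipy's name for Manhattan (L1) distance\n",
        "    return cdist(np.expand_dims(embedding1, axis=0), np.expand_dims(embedding2, axis=0), metric='cityblock')[0][0]\n",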
        "\n",
        "print('similar phrases:')\n",
        "print(calc_distance(embed_sequence('sunny day today')\n",
        "                  , embed_sequence('rainy morning presently')))\n",
        "\n",
        "print('different phrases:')\n",
        "print(calc_distance(embed_sequence('sunny day today')\n",
        "                  , embed_sequence('perhaps reality is painful')))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "X0CDcKq2HJRv",
        "outputId": "a9452ee7-cb98-41dd-af84-2b2397a39f25"
      },
      "execution_count": 56,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "similar phrases:\n",
            "8.496297497302294\n",
            "different phrases:\n",
            "11.832107525318861\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Defining Documents"
      ],
      "metadata": {
        "id": "5RPCor0TKRie"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "\"\"\"Defining documents.\n",
        "For simplicity's sake I only included words the embedder knows. You could just\n",
        "filter out all the words the embedder doesn't know, though. After all, the retrieval\n",
        "is done on a mean of all embeddings, so a missing word or two is of little consequence.\n",
        "\"\"\"\n",
        "documents = {\"menu\": \"ratatouille is a stew thats twelve dollars and fifty cents also gazpacho is a salad thats thirteen dollars and ninety eight cents also hummus is a dip thats eight dollars and seventy five cents also meat sauce is a pasta dish thats twelve dollars also penne marinera is a pasta dish thats eleven dollars also shrimp and linguini is a pasta dish thats fifteen dollars\",\n",
        "             \"events\": \"on thursday we have karaoke and on tuesdays we have trivia\",\n",
        "             \"allergins\": \"the only item on the menu common allergen is hummus which contain pine nuts\",\n",
        "             \"info\": \"the resteraunt was founded by two brothers in two thousand and three\"}"
      ],
      "metadata": {
        "id": "Q-xcLiI4JtXY"
      },
      "execution_count": 57,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Defining Retrieval"
      ],
      "metadata": {
        "id": "R58tdc2sSKNB"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "\"\"\"Defining a function that retrieves the most relevant document.\n",
        "\"\"\"\n",
        "\n",
        "def retreive_relevent(prompt, documents=documents):\n",
        "    # start at infinity so the first document always becomes the best match\n",
        "    min_dist = float('inf')\n",
        "    r_docname = \"\"\n",
        "    r_doc = \"\"\n",
        "\n",
        "    # the prompt embedding is the same for every document, so compute it once\n",
        "    prompt_embedding = embed_sequence(prompt)\n",
        "\n",
        "    for docname, doc in documents.items():\n",
        "        dist = calc_distance(prompt_embedding, embed_sequence(doc))\n",
        "\n",
        "        if dist < min_dist:\n",
        "            min_dist = dist\n",
        "            r_docname = docname\n",
        "            r_doc = doc\n",
        "\n",
        "    return r_docname, r_doc\n",
        "\n",
        "\n",
        "prompt = 'what pasta dishes do you have'\n",
        "print(f'finding relevent doc for \"{prompt}\"')\n",
        "print(retreive_relevent(prompt))\n",
        "print('----')\n",
        "prompt = 'what events do you guys do'\n",
        "print(f'finding relevent doc for \"{prompt}\"')\n",
        "print(retreive_relevent(prompt))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Ibs6GhhwRric",
        "outputId": "1b4eec1b-dade-456e-e672-06b4ae6e024e"
      },
      "execution_count": 65,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "finding relevent doc for \"what pasta dishes do you have\"\n",
            "('menu', 'ratatouille is a stew thats twelve dollars and fifty cents also gazpacho is a salad thats thirteen dollars and ninety eight cents also hummus is a dip thats eight dollars and seventy five cents also meat sauce is a pasta dish thats twelve dollars also penne marinera is a pasta dish thats eleven dollars also shrimp and linguini is a pasta dish thats fifteen dollars')\n",
            "----\n",
            "finding relevent doc for \"what events do you guys do\"\n",
            "('events', 'on thursday we do karaoke and on tuesdays we do trivia')\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Defining Retrieval and Augmentation"
      ],
      "metadata": {
        "id": "EalA9JmkZ4xg"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "\"\"\"Defining retrieval and augmentation.\n",
        "Creating a function that does retrieval and augmentation;\n",
        "its output can be passed straight to the model.\n",
        "\"\"\"\n",
        "def retreive_and_agument(prompt, documents=documents):\n",
        "    docname, doc = retreive_relevent(prompt, documents)\n",
        "    return f\"Answer the customers prompt based on the folowing documents:\\n==== document: {docname} ====\\n{doc}\\n====\\n\\nprompt: {prompt}\\nresponse:\"\n",
        "\n",
        "prompt = 'what events do you guys do'\n",
        "print(f'prompt for \"{prompt}\":\\n')\n",
        "print(retreive_and_agument(prompt))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "iWXhRnTzZ4TB",
        "outputId": "70cad5a3-e5ed-4dd7-fa77-485158dc7f65"
      },
      "execution_count": 78,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "prompt for \"what events do you guys do\":\n",
            "\n",
            "Answer the customers prompt based on the folowing documents:\n",
            "==== document: events ====\n",
            "on thursday we do karaoke and on tuesdays we do trivia\n",
            "====\n",
            "\n",
            "prompt: what events do you guys do\n",
            "response:\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Defining RAG and prompting OpenAI's LLM"
      ],
      "metadata": {
        "id": "VN3Czie_Vvl-"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install openai==0.28.1"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "eaPOsnPKZGNX",
        "outputId": "23b2e1d4-82c1-4503-b73b-bcd9ed59eaf1"
      },
      "execution_count": 67,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting openai\n",
            "  Downloading openai-0.28.1-py3-none-any.whl (76 kB)\n",
            "Requirement already satisfied: requests>=2.20 in /usr/local/lib/python3.10/dist-packages (from openai) (2.31.0)\n",
            "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from openai) (4.66.1)\n",
            "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from openai) (3.8.5)\n",
            "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.20->openai) (3.3.0)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.20->openai) (3.4)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.20->openai) (2.0.6)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.20->openai) (2023.7.22)\n",
            "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->openai) (23.1.0)\n",
            "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->openai) (6.0.4)\n",
            "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->openai) (4.0.3)\n",
            "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->openai) (1.9.2)\n",
            "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->openai) (1.4.0)\n",
            "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->openai) (1.3.1)\n",
            "Installing collected packages: openai\n",
            "Successfully installed openai-0.28.1\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# reading the OpenAI API key from a file on Google Drive\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n",
        "\n",
        "with open(\"/content/drive/My Drive/Colab Notebooks/Credentials/OpenAI-danielDemoKey.txt\", \"r\") as myfile:\n",
        "    # strip surrounding whitespace so a trailing newline doesn't end up in the key\n",
        "    OPENAI_API_TOKEN = myfile.read().strip()\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Ozb56SQ8YiJ0",
        "outputId": "b89c20b2-ba0f-4e13-af6e-5d3e54d06cb5"
      },
      "execution_count": 68,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mounted at /content/drive\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "\"\"\"Using RAG with OpenAI's GPT model.\n",
        "Note: openai.Completion is the legacy (pre-1.0) interface of the openai package.\n",
        "\"\"\"\n",
        "\n",
        "import openai\n",
        "openai.api_key = OPENAI_API_TOKEN\n",
        "\n",
        "prompts = ['what pasta dishes do you have', 'what events do you guys do', 'oh cool what is karaoke']\n",
        "\n",
        "for prompt in prompts:\n",
        "\n",
        "    ra_prompt = retreive_and_agument(prompt)\n",
        "    response = openai.Completion.create(model=\"gpt-3.5-turbo-instruct\", prompt=ra_prompt, max_tokens=80).choices[0].text\n",
        "\n",
        "    print(f'prompt: \"{prompt}\"')\n",
        "    print(f'response: {response}')"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "OjQ3nLgyZErb",
        "outputId": "151f6809-8f66-4a52-c763-1a22de4546be"
      },
      "execution_count": 89,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "prompt: \"what pasta dishes do you have\"\n",
            "response:  We have a variety of pasta dishes including meat sauce for $12, penne marinera for $11, and shrimp and linguini for $15.\n",
            "prompt: \"what events do you guys do\"\n",
            "response:  On Thursdays, we do karaoke and on Tuesdays, we do trivia.\n",
            "prompt: \"oh cool what is karaoke\"\n",
            "response:  Karaoke is a fun event where people can sing along to their favorite songs while the lyrics are displayed on a screen. Our karaoke night is held on Thursdays, so make sure to come and join us for a great time!\n"
          ]
        }
      ]
    }
  ]
}